class MTAModel():
    """Image-conditioned tri-plane model with a point-feature NeRF decoder.

    Input images are encoded into texture tri-planes by ``style_tplane``,
    fused against a learned query plane by ``attn_module``, rendered by a
    renderer object and finally refined by ``up_renderer``.

    NOTE(review): no base class is declared on this class, yet ``__init__``
    forwards two arguments to ``super().__init__`` and the methods read
    ``train_renderer``, ``val_renderer``, ``background``, ``resize`` and
    ``percep_loss``, none of which are assigned here. Presumably a project
    base class provides them -- confirm the intended parent.
    """

    def __init__(self, config_dict, data_meta_info):
        """Build the sub-networks and learnable tensors.

        Args:
            config_dict: Forwarded unchanged to ``super().__init__``.
            data_meta_info: Forwarded unchanged to ``super().__init__``.
        """
        super().__init__(config_dict, data_meta_info)
        # Image encoder producing 96 output channels (96 = 3 * 32;
        # presumably three 32-channel planes -- TODO confirm with StyleUNet).
        self.style_tplane = StyleUNet(
            in_size=512, out_size=256, in_dim=3, out_dim=96, activation=False,
        )
        # Per-sample decoder used by the renderer (stored here, not called
        # directly by this class).
        self.nerf_mlp = PointsDecoder(
            in_dim=32, out_dim=32, points_number=5023, points_k=8,
        )
        # Maps the 32-channel fine render to a 3-channel image.
        self.up_renderer = StyleUNet(
            in_size=512, out_size=512, in_dim=32, out_dim=3,
        )
        # Learnable feature tensor of shape (1, 32, 5023).
        self.points_features = torch.nn.Parameter(
            torch.randn(1, 32, 5023), requires_grad=True,
        )
        # Attention: a learnable query tri-plane and the fusion module.
        self.query_style_tplane = torch.nn.Parameter(
            torch.randn(1, 1, 3, 32, 256, 256), requires_grad=True,
        )
        self.attn_module = SoftTPTransformer(dim=32, qkv_bias=True)

    def forward_train(self, input_images, images, points, transforms):
        """Run one training forward pass.

        Args:
            input_images: Source images fed to ``style_tplane``.
            images: Target images; passed through untouched under ``'images'``.
            points: Point positions handed to the renderer.
            transforms: Camera transform(s) handed to ``set_position``.

        Returns:
            dict with keys ``'images'``, ``'gen_coarse'``, ``'gen_fine'``,
            ``'gen_sr'`` and ``'densities'``.
        """
        # Camera pose for this batch.
        self.train_renderer.set_position(transform_matrix=transforms)

        # Encode the inputs; with probability 0.3 keep only the first entry
        # along dim 1 (presumably the per-sample image axis -- TODO confirm).
        planes = self.style_tplane(input_images)
        keep_first_only = random.random() < 0.3
        if keep_first_only:
            planes = planes[:, :1]
        fused_planes = self.attn_module(self.query_style_tplane, planes)

        # Two-path render with stratified sampling; the trailing keyword
        # arguments are the "nerf_fn params" consumed by the renderer.
        coarse, fine, params_dict = self.train_renderer.render(
            two_path=True,
            stratified_sampling=True,
            points_position=points,
            points_features=self.points_features,
            tex_tplanes=fused_planes,
        )

        # The up-sampler consumes every channel of the fine render; only the
        # first three channels of coarse/fine are reported.
        upsampled = self.up_renderer(fine)
        return {
            'images': images,
            'gen_coarse': coarse[:, :3],
            'gen_fine': fine[:, :3],
            'gen_sr': upsampled,
            'densities': params_dict['coarse']['densities'],
        }

    @torch.no_grad()
    def forward_inference(
            self, input_images, points, transforms, background, infos, **kwargs
        ):
        """Render with cached tri-planes (gradients disabled).

        The fused tri-planes are computed on the first call and stored on
        ``self.texture_planes``; later calls reuse them regardless of the
        ``input_images`` passed in.

        NOTE(review): ``background`` and ``infos`` are accepted but unused;
        the renderer receives ``self.background`` instead. Confirm whether
        the ``background`` argument was meant to be forwarded.

        Returns:
            dict with keys ``'input_images'``, ``'gen_fine'``, ``'gen_sr'``
            and ``'depth'``.
        """
        planes_cached = hasattr(self, 'texture_planes')
        if not planes_cached:
            encoded = self.style_tplane(input_images)
            self.texture_planes = self.attn_module(
                self.query_style_tplane, encoded
            )
            print('Tri-Planes built.')

        # Camera pose for this call.
        self.val_renderer.set_position(transform_matrix=transforms)

        # Two-path render without stratified sampling; the coarse output is
        # discarded.
        _, fine, params_dict = self.val_renderer.render(
            two_path=True,
            stratified_sampling=False,
            background=self.background,
            points_position=points,
            points_features=self.points_features,
            tex_tplanes=self.texture_planes,
        )
        upsampled = self.up_renderer(fine)
        return {
            'input_images': input_images,
            'gen_fine': fine[:, :3],
            'gen_sr': upsampled,
            'depth': params_dict['depth'],
        }

    def calc_metrics(self, images, gen_coarse, gen_fine, gen_sr, densities, bbox, **kwargs):
        """Compute the training losses and a PSNR value.

        Coarse and fine renders are compared against ``images`` resized via
        ``self.resize(images, gen_coarse)``; the up-sampled render is compared
        against ``images`` directly. ``bbox`` and ``**kwargs`` are unused.

        Returns:
            ``(metrics, psnr)`` where ``metrics`` holds ``'percep_loss'``
            (mean of three perceptual terms times 1e-2), ``'img_loss'`` (sum
            of three L1 terms) and ``'density_loss'`` (L2 norm of
            ``densities`` times 1e-5), and ``psnr`` is derived from the
            detached MSE between ``gen_sr`` and ``images``.
        """
        l1 = torch.nn.functional.l1_loss
        target_low = self.resize(images, gen_coarse)
        target_full = images
        # (prediction, target) pairs shared by the perceptual and L1 terms.
        pairs = (
            (gen_coarse, target_low),
            (gen_fine, target_low),
            (gen_sr, target_full),
        )
        percep_terms = [self.percep_loss(pred, target) for pred, target in pairs]
        l1_terms = [l1(pred, target) for pred, target in pairs]

        percep_total = (percep_terms[0] + percep_terms[1] + percep_terms[2]) / 3 * 1e-2
        l1_total = (l1_terms[0] + l1_terms[1] + l1_terms[2])
        density_reg = torch.norm(densities, p=2) * 1e-5

        metrics = {
            'percep_loss': percep_total,
            'img_loss': l1_total,
            'density_loss': density_reg,
        }
        mse = torch.nn.functional.mse_loss(gen_sr, images)
        psnr = -10.0 * torch.log10(mse.detach())
        return metrics, psnr


class PointsDecoder(torch.nn.Module):
    """Decode per-sample density and colour features.

    Each sample combines features read from texture tri-planes with features
    interpolated from nearby points, then runs them through small MLP heads.
    """

    def __init__(self, in_dim, out_dim, points_number, points_k=4):
        """Create the encoders and MLP heads.

        Args:
            in_dim: Point-feature width, forwarded to ``DynamicPointsQuerier``.
            out_dim: Output width of the colour head.
            points_number: Forwarded to ``DynamicPointsQuerier``.
            points_k: Number of neighbours requested from the querier.
        """
        super().__init__()
        # View-direction encoder: 4 harmonics -> 4 * 2 * 3 + 3 = 27 channels
        # (assumes HarmonicEmbedding appends the raw input -- TODO confirm).
        direction_harmonics = 4
        direction_embed_dim = direction_harmonics * 2 * 3 + 3
        self.pos_encoder = HarmonicEmbedding(direction_harmonics)
        # Neighbour lookup.
        self.points_k = points_k
        self.points_querier = DynamicPointsQuerier(in_dim, points_number)
        # Shared trunk. The input width is fixed at 32 + 32 (presumably 32
        # tri-plane channels plus the querier's default 32 outputs).
        hidden = 128
        self.feature_layers = torch.nn.Sequential(
            torch.nn.Linear(32 + 32, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden),
        )
        # Density head: a single non-negative value per sample.
        self.density_layers = torch.nn.Sequential(
            torch.nn.Linear(hidden, 1),
            torch.nn.Softplus(beta=10.0),
        )
        # Colour head, conditioned on the encoded view direction.
        self.rgb_layers = torch.nn.Sequential(
            torch.nn.Linear(hidden + direction_embed_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, out_dim),
        )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, coordinates, directions, points_position, points_features, tex_tplanes):
        """Evaluate the decoder at the given samples.

        Args:
            coordinates: Sample positions; must be rank 3. Clamped to
                [-0.999, 0.999] before any lookup.
            directions: Per-sample view directions; L2-normalised over the
                last axis before encoding.
            points_position: Point positions passed to the querier.
            points_features: Point features passed to the querier.
            tex_tplanes: Tri-planes passed to ``sample_from_planes``.

        Returns:
            ``(densities, rgb, extras)`` where ``extras`` is a dict with the
            keys ``'densities'`` and ``'distances'``.
        """
        # The unpacking doubles as a rank-3 check on ``coordinates``.
        _n, _m, _c = coordinates.shape
        coords = coordinates.clamp(-0.999, 0.999)

        # Gather both feature sources and concatenate along channels.
        plane_feats = sample_from_planes(coords, tex_tplanes)
        point_feats, distances = self.points_querier(
            coords, points_position, points_features, K=self.points_k
        )
        trunk = self.feature_layers(torch.cat([plane_feats, point_feats], dim=-1))

        # Softplus output s >= 0 is mapped to 1 - exp(-s), i.e. into [0, 1).
        softplus_out = self.density_layers(trunk)
        densities = 1 - (-softplus_out).exp()

        # View-dependent colour.
        unit_dirs = torch.nn.functional.normalize(directions, dim=-1)
        view_embed = self.pos_encoder(unit_dirs)
        rgb = self.rgb_layers(torch.cat([trunk, view_embed], dim=-1))

        # Only the first three channels are squashed, then stretched slightly
        # past [0, 1] to the open range (-0.001, 1.001).
        margin = 0.001
        rgb[..., :3] = self.sigmoid(rgb[..., :3])
        rgb[..., :3] = rgb[..., :3] * (1 + 2 * margin) - margin

        extras = {'densities': densities, 'distances': distances}
        return densities, rgb, extras


class DynamicPointsQuerier(torch.nn.Module):
    """Interpolate point features at query coordinates from K nearest points.

    For every query coordinate the K nearest points are looked up, each
    neighbour's feature is concatenated with a harmonic embedding of the
    query-to-neighbour offset, passed through a small MLP, and the K results
    are blended with inverse-distance weights.
    """

    # Lower bound applied to the k-NN distances before they are inverted, so a
    # query that coincides with a point cannot produce an infinite weight
    # (which would become NaN after normalisation).
    _MIN_DIST = 1e-8

    def __init__(self, in_dim, points_number, outputs_dim=32, bin_size=1/64):
        """Create the offset encoder and the per-neighbour MLP.

        Args:
            in_dim: Width of the incoming point features.
            points_number: Number of points; stored on the instance only.
            outputs_dim: Width of the returned features.
            bin_size: Accepted for interface compatibility; not used.
        """
        super().__init__()
        self.points_dim = in_dim
        self.outputs_dim = outputs_dim
        self.points_number = points_number
        # 6 harmonics -> 6 * 2 * 3 + 3 = 39 offset-embedding channels
        # (assumes HarmonicEmbedding appends the raw input -- TODO confirm).
        n_harmonic_functions = 6
        embedding_dim = n_harmonic_functions * 2 * 3 + 3
        self.pos_encoder = HarmonicEmbedding(n_harmonic_functions)
        self.point_layers = torch.nn.Sequential(
            torch.nn.Linear(in_dim+embedding_dim, outputs_dim*2, bias=True),
            torch.nn.ReLU(),
            torch.nn.Linear(outputs_dim*2, outputs_dim, bias=True),
        )

    def forward(self, coordinates, points_position, points_features, K):
        """Return blended neighbour features and the raw k-NN distances.

        Args:
            coordinates: Query positions, batch first (rank 3).
            points_position: Candidate point positions, batch first.
            points_features: Per-batch feature tables indexed by point id;
                each ``points_features[b]`` is used as an embedding matrix.
            K: Number of neighbours per query.

        Returns:
            ``(sample_points_features, dist)``; ``dist`` is exactly what
            ``knn_points`` returned (unclamped).
        """
        # ``neighbors`` (not ``nn``) so the ``torch.nn`` alias is not shadowed.
        dist, idx, neighbors = knn_points(
            coordinates.float(), points_position.float(), K=K, return_nn=True
        )
        # Per-batch lookup of each neighbour's feature row.
        neighbor_features = torch.stack([
            torch.nn.functional.embedding(idx[bidx], points_features[bidx])
            for bidx in range(idx.shape[0])
        ])
        # Offset from every neighbour to its query, harmonically encoded.
        relative_pos = coordinates[:, :, None] - neighbors
        relative_pos_embed = self.pos_encoder(relative_pos)
        # Per-neighbour MLP on [feature, offset embedding].
        neighbor_features = torch.cat([neighbor_features, relative_pos_embed], dim=-1)
        neighbor_features = self.point_layers(neighbor_features)
        # Inverse-distance weights; clamp first to avoid division by zero.
        points_weights = 1 / dist.clamp_min(self._MIN_DIST)
        # NOTE(review): the weights are normalised over dim=1 while the
        # weighted sum below runs over dim=2. If dim=2 is the neighbour axis
        # (as the sum suggests), normalising over dim=-1 may have been
        # intended. Left unchanged because altering it would change the
        # outputs of already-trained weights -- confirm before fixing.
        points_weights = points_weights / points_weights.sum(dim=1, keepdim=True)
        sample_points_features = (neighbor_features * points_weights.unsqueeze(-1)).sum(dim=2)
        return sample_points_features, dist


class SoftTPTransformer(nn.Module):
    """Single-head cross-attention applied independently at each plane location.

    Inputs are laid out as ``(b, n, t, c, h, w)``. For every ``(b, t, h, w)``
    location the ``n`` key entries are attended to by the query entry, and
    the attention-weighted average of the *raw* keys is returned.
    """

    def __init__(self, dim, qkv_bias=False):
        """Create the query/key projections.

        Args:
            dim: Channel width ``c`` of the planes.
            qkv_bias: Whether the projection ``Linear`` layers carry a bias.
        """
        super().__init__()
        # Fixed (non-trainable) 1/sqrt(dim) scaling of the attention logits.
        self.scale = nn.Parameter(torch.tensor([dim**-0.5]), requires_grad=False)

        self.attend = nn.Softmax(dim=-1)
        self.to_q_ = self._make_projection(dim, qkv_bias)
        self.to_k_ = self._make_projection(dim, qkv_bias)

    @staticmethod
    def _make_projection(dim, bias):
        """Build the LayerNorm + three-layer GELU MLP shared by q and k."""
        return nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim, bias=bias),
            nn.GELU(),
            nn.Linear(dim, dim, bias=bias),
            nn.GELU(),
            nn.Linear(dim, dim, bias=bias),
        )

    @staticmethod
    def _to_tokens(planes):
        """Reshape ``(b, n, t, c, h, w)`` planes to ``(b, t*h*w, n, c)``."""
        b, n, t, c, h, w = planes.shape
        return planes.permute(0, 2, 4, 5, 1, 3).reshape(b, t * h * w, n, c)

    def forward(self, query, key):
        """Fuse ``key`` planes using ``query``.

        Args:
            query: Tensor ``(bq, 1, 3, c, h, w)``. ``bq`` may equal the key
                batch size or be 1, in which case the query is broadcast
                over the key batch.
            key: Tensor ``(b, n, 3, c, h, w)``; also used as the value.

        Returns:
            Tensor ``(b, 3, c, h, w)``.

        Raises:
            AssertionError: If the plane axis is not 3 or the query holds
                more than one entry along dim 1.
        """
        B, N, T, C, H, W = key.shape
        assert T == 3
        assert query.shape[1] == 1

        # Batch stays a separate leading axis (instead of being merged into
        # the token axis) so a batch-1 query broadcasts against any key batch.
        q_tokens = self._to_tokens(query)   # (bq, t*h*w, 1, c)
        k_tokens = self._to_tokens(key)     # (b,  t*h*w, n, c)

        q_ = self.to_q_(q_tokens)
        k_ = self.to_k_(k_tokens)
        dots = torch.matmul(q_, k_.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        # Values are the unprojected keys.
        fused = torch.matmul(attn, k_tokens)    # (b, t*h*w, 1, c)

        # Drop the singleton query axis and restore (b, t, c, h, w).
        fused = fused[:, :, 0].reshape(fused.shape[0], T, H, W, C)
        return fused.permute(0, 1, 4, 2, 3)